{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3d3e8316-6769-44e1-b522-cb4b35fc4541",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/jalammar/ecco/blob/main/notebooks/Ecco_CCA_Similarity.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c05bfa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install ecco"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55290353-1c0f-4778-abd8-bb4c09969e1a",
   "metadata": {},
   "source": [
    "Load Ecco and DistilBERT (`distilbert-base-uncased`), a smaller, distilled version of BERT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "228061f6-2cfc-47ea-9357-6789c81745d1",
   "metadata": {
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.weight']\n",
      "- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "import ecco\n",
    "lm = ecco.from_pretrained('distilbert-base-uncased', gpu=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27d1f126-e938-42f6-8e8c-fb97f79a2b74",
   "metadata": {},
   "source": [
    "Let's give BERT a passage of text to process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b1abd436-50bf-4722-b113-73eaa795020d",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = '''Now I ask you: what can be expected of man since he is a being endowed with strange qualities? Shower upon him every earthly blessing, drown him in a sea of happiness, so that nothing but bubbles of bliss can be seen on the surface; give him economic prosperity, such that he should have nothing else to do but sleep, eat cakes and busy himself with the continuation of his species, and even then out of sheer ingratitude, sheer spite, man would play you some nasty trick. He would even risk his cakes and would deliberately desire the most fatal rubbish, the most uneconomical absurdity, simply to introduce into all this positive good sense his fatal fantastic element. It is just his fantastic dreams, his vulgar folly that he will desire to retain, simply in order to prove to himself--as though that were so necessary-- that men still are men and not the keys of a piano, which the laws of nature threaten to control so completely that soon one will be able to desire nothing but by the calendar. And that is not all: even if man really were nothing but a piano-key, even if this were proved to him by natural science and mathematics, even then he would not become reasonable, but would purposely do something perverse out of simple ingratitude, simply to gain his point. And if he does not find means he will contrive destruction and chaos, will contrive sufferings of all sorts, only to gain his point! He will launch a curse upon the world, and as only man can curse (it is his privilege, the primary distinction between him and other animals), may be by his curse alone he will attain his object--that is, convince himself that he is a man and not a piano-key!\n",
    "'''\n",
    "\n",
    "inputs = lm.tokenizer([text], return_tensors=\"pt\")\n",
    "output = lm(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "470b6fa8-1f45-42e6-afa8-ba6fbe576bb4",
   "metadata": {},
   "source": [
    "The `output` variable now contains the result of BERT processing the passage of text. The hidden states after each layer are available through `output._get_encoder_hidden_states()`, and the input embeddings through `output.embedding_states`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1c97a987-c239-4f51-8f39-f4f00dfa9464",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((768, 363), (768, 363), 6)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
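   ],
   "source": [
    "# Per-layer hidden states, each shaped (batch, tokens, embed_dim)\n",
    "hidden_states = output._get_encoder_hidden_states()\n",
    "# Take the first (only) batch item and transpose to (embed_dim, tokens)\n",
    "embed = output.embedding_states.detach().numpy()[0,:,:].T\n",
    "hidden_state_layer = [layer.detach().numpy()[0,:,:].T for layer in hidden_states]\n",
    "# Show the resulting shapes and the number of layers\n",
    "embed.shape, hidden_state_layer[0].shape, len(hidden_state_layer)"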
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41ee4085-a697-4753-98b8-f930b4abcad0",
   "metadata": {},
   "source": [
    "`embed` now contains the embeddings of the inputs. Its dimensions are (embed_dim, number of tokens), here (768, 363).\n",
    "`hidden_state_layer` holds the outputs of each of the model's 6 layers. Each layer's output also has dimensions (embed_dim, number of tokens).\n",
    "\n",
    "This is how to calculate the CKA (centered kernel alignment) similarity score between the embedding layer and the output of the first layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9fe8fa2b-39b7-4a04-bf4e-c62cffe3f2ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9042735255823328"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from ecco import analysis\n",
    "analysis.cka(embed, hidden_state_layer[0])"
   ]
  },
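  {
   "cell_type": "markdown",
   "id": "c1a0f3e1",
   "metadata": {},
   "source": [
    "For intuition, linear CKA can be written in a few lines of NumPy. The cell below is a minimal sketch under assumptions, not Ecco's implementation: `linear_cka` is a hypothetical helper introduced here, and its result is not guaranteed to match `analysis.cka` exactly. It expects activations shaped (neurons, tokens), the same layout as `embed`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a0f3e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def linear_cka(x, y):\n",
    "    # Hypothetical helper (not part of Ecco): linear CKA for activations shaped (neurons, tokens)\n",
    "    # Center each neuron's activations across tokens\n",
    "    x = x - x.mean(axis=1, keepdims=True)\n",
    "    y = y - y.mean(axis=1, keepdims=True)\n",
    "    # Squared Frobenius norm of the cross-covariance, normalized by the self-covariance norms\n",
    "    cross = np.linalg.norm(y @ x.T, ord='fro') ** 2\n",
    "    return cross / (np.linalg.norm(x @ x.T, ord='fro') * np.linalg.norm(y @ y.T, ord='fro'))\n",
    "\n",
    "# Sanity check on random data: a representation compared with itself scores 1.0 (up to rounding)\n",
    "rng = np.random.default_rng(0)\n",
    "acts = rng.normal(size=(20, 100))\n",
    "linear_cka(acts, acts)"
   ]
  },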
  {
   "cell_type": "markdown",
   "id": "831da973-63bc-4a53-8022-8017ee82af57",
   "metadata": {},
   "source": [
    "When we compare the embeddings with the output of the second layer, we see less similarity (about 0.78, down from 0.90):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d428f519-fb48-402d-8bea-962783fe36de",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7774274463814453"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "analysis.cka(embed, hidden_state_layer[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "823e071c-ccc2-458e-b792-8d451bd48e41",
   "metadata": {},
   "source": [
    "The trend continues with the third layer, where the score drops to about 0.69:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8621e099-a768-4e17-86e6-6281c4fe4a0f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6922863613160068"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "analysis.cka(embed, hidden_state_layer[2])"
   ]
  },
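  {
   "cell_type": "markdown",
   "id": "c1a0f3e3",
   "metadata": {},
   "source": [
    "Rather than calling `analysis.cka` once per layer, we can loop over every entry in `hidden_state_layer`. This is a short sketch that reuses only the variables defined above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a0f3e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CKA similarity between the input embeddings and the output of every layer\n",
    "for i, layer in enumerate(hidden_state_layer):\n",
    "    print(f'CKA - Embed vs. layer {i}:', analysis.cka(embed, layer))"
   ]
  },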
  {
   "cell_type": "markdown",
   "id": "7c2de6f2-fe10-4672-ae51-cfdfe2382be2",
   "metadata": {},
   "source": [
    "We can also try `cca`, `svcca` and `pwcca`. These methods require more tokens than neurons (and advise 10x as many tokens as neurons to get a proper similarity score), but our passage has only 363 tokens against 768 neurons per layer, so we need to choose a subset of the neurons.\n",
    "\n",
    "Let's compare the similarities of the first 50 neurons."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fc3f50a6-eaa0-498e-896a-11d333d7fb5f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CCA - Embed vs. layer 0: 0.8518187688700852\n",
      "CCA - Embed vs. layer 1: 0.7220358158064838\n"
     ]
    }
   ],
   "source": [
    "print(\"CCA - Embed vs. layer 0:\", analysis.cca(embed[:50,:], hidden_state_layer[0][:50,:]))\n",
    "print(\"CCA - Embed vs. layer 1:\", analysis.cca(embed[:50,:], hidden_state_layer[1][:50,:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ebb5d691-af03-4d0c-9282-c5338671df12",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SVCCA - Embed vs. layer 0: 0.7830642779004243\n",
      "SVCCA - Embed vs. layer 1: 0.6833412966387719\n"
     ]
    }
   ],
   "source": [
    "print(\"SVCCA - Embed vs. layer 0:\", analysis.svcca(embed[:50,:], hidden_state_layer[0][:50,:]))\n",
    "print(\"SVCCA - Embed vs. layer 1:\", analysis.svcca(embed[:50,:], hidden_state_layer[1][:50,:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5b3e47b3-e5a6-47ce-8964-29eb086e29ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PWCCA - Embed vs. layer 0: 0.8695735290407357\n",
      "PWCCA - Embed vs. layer 1: 0.7461958851582353\n"
     ]
    }
   ],
   "source": [
    "print(\"PWCCA - Embed vs. layer 0:\", analysis.pwcca(embed[:50,:], hidden_state_layer[0][:50,:]))\n",
    "print(\"PWCCA - Embed vs. layer 1:\", analysis.pwcca(embed[:50,:], hidden_state_layer[1][:50,:]))"
   ]
  },
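  {
   "cell_type": "markdown",
   "id": "c1a0f3e5",
   "metadata": {},
   "source": [
    "The same comparison can be run across all layers for the three CCA-based methods. This is a sketch that reuses the calls above. `n_neurons` is a name introduced here for the size of the neuron subset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a0f3e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_neurons = 50  # size of the neuron subset, as in the cells above\n",
    "for name, method in [('CCA', analysis.cca), ('SVCCA', analysis.svcca), ('PWCCA', analysis.pwcca)]:\n",
    "    for i, layer in enumerate(hidden_state_layer):\n",
    "        print(f'{name} - Embed vs. layer {i}:', method(embed[:n_neurons, :], layer[:n_neurons, :]))"
   ]
  },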
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27e4e49a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
